package com.xl.competition.old.task1

import org.apache.log4j.{Level, Logger}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.apache.spark.sql.{DataFrame, Row, SparkSession}

import java.util
import scala.collection.mutable

/**
 * @author: xl
 * @createTime: 2023/10/6 10:58:40
 * @program: com.xl.competition
 * @description:
 */
object SparkToMatrix {
  /**
   * Builds a user-to-user co-purchase adjacency matrix from the `fact_orders`
   * CSV and prints it.
   *
   * Row layout: first column is the user name (label), then one 0/1 column per
   * user (sorted alphabetically). Cell (u, v) is 1 iff some OTHER user v bought
   * at least one `goods_id` that user u also bought; the diagonal stays 0.
   *
   * Fixes over the previous version:
   *  - no longer hard-coded to exactly 5 users (the old `Seq.fill(...)(0,0,0,0,0)`
   *    tuple made `toDF(users: _*)` fail for any other user count);
   *  - removed the `var i` counter mutated inside an RDD closure (undefined
   *    behavior outside single-partition local mode);
   *  - removed the NPE-prone `relation.isEmpty || relation == null` check;
   *  - purchases are collected once and grouped in memory instead of launching
   *    one Spark filter job per order row and rebuilding the DataFrame each time.
   */
  def main(args: Array[String]): Unit = {
    Logger.getLogger("org").setLevel(Level.WARN)
    val spark: SparkSession = SparkSession.builder()
      .master("local[1]")
      .appName("MatrixConversion")
      .getOrCreate()

    // Load the fact_orders table into a DataFrame.
    // NOTE(review): path is machine-specific — consider taking it from args.
    val factOrdersDF: DataFrame = spark.read.option("header", value = true)
      .csv("F:\\IDEA_BigData\\com.xl.competition\\competition.spark\\src\\main\\resources\\fact_orders.csv")

    // All distinct users, sorted so the matrix ordering is deterministic.
    val users: Array[String] = factOrdersDF
      .select("consumer")
      .distinct()
      .rdd
      .map(r => r(0).toString)
      .collect()
      .sorted

    // user -> column index within a matrix row (column 0 is the label column).
    val columnOf: Map[String, Int] =
      users.zipWithIndex.map { case (u, idx) => u -> (idx + 1) }.toMap

    // Collect (consumer, goods_id) pairs once; all grouping happens driver-side.
    val purchases: Array[(String, String)] = factOrdersDF
      .select("consumer", "goods_id")
      .rdd
      .map(r => (r.getString(0), r.getString(1)))
      .collect()

    // goods_id -> every user that bought it.
    val buyersOfGoods: Map[String, Set[String]] =
      purchases.groupBy(_._2).map { case (goods, ps) => goods -> ps.map(_._1).toSet }

    // user -> all OTHER users sharing at least one goods_id with that user.
    val relatedUsers: Map[String, Set[String]] =
      purchases.groupBy(_._1).map { case (user, ps) =>
        user -> (ps.flatMap(p => buyersOfGoods(p._2)).toSet - user)
      }

    // Build each matrix row directly: label followed by 0/1 flags.
    val rows: Seq[Row] = users.toSeq.map { user =>
      val flags: Array[Any] = Array.fill[Any](users.length)(0)
      for (other <- relatedUsers.getOrElse(user, Set.empty[String])) {
        // -1 because `columnOf` accounts for the leading label column.
        flags(columnOf(other) - 1) = 1
      }
      Row.fromSeq(user +: flags.toSeq)
    }

    // Unnamed string label column followed by one IntegerType column per user,
    // mirroring the schema the original tuple-based DataFrame produced.
    val structType = StructType(
      StructField("", StringType) +: users.map(u => StructField(u, IntegerType))
    )
    val rowRdd: RDD[Row] = spark.sparkContext.makeRDD(rows)
    val matrix: DataFrame = spark.createDataFrame(rowRdd, structType)

    matrix.show()
    // TODO(review): matrix is only printed; the "save as txt" step was never implemented.
    spark.stop()
  }
}
